#!/usr/bin/python
# -*- coding:utf-8 -*-

import numpy as np
import os
import tensorflow as tf
import matplotlib.pyplot as plt

# Reset the default graph so the script can be re-run (e.g. in a notebook).
tf.reset_default_graph()

# Fix the random seeds for reproducibility.
np.random.seed(42)
tf.set_random_seed(42)

# Load the MNIST data set from ./temp/data with one-hot encoded labels.
from tensorflow.examples.tutorials.mnist import input_data
data_path = os.path.join('.', 'temp', 'data')
mnist = input_data.read_data_sets(data_path, one_hot=True)
# Mini-batch size used for training.
batch_size = 100
# Number of full mini-batches per training epoch.
n_batch = mnist.train.num_examples // batch_size


# Inputs: images as [batch, 28, 28, 1] floats, labels as integer class indices.
X = tf.placeholder(shape=[None, 28, 28, 1], dtype=tf.float32, name="X")
y = tf.placeholder(shape=[None], dtype=tf.int64, name="y")

# Layer 1: an ordinary convolutional layer (256 filters, 9x9, stride 1, ReLU).
conv1_params = {
    "filters": 256,
    "kernel_size": 9,
    "strides": 1,
    "padding": "valid",
    "activation": tf.nn.relu,
}
conv1 = tf.layers.conv2d(X, name="conv1", **conv1_params)
print('conv1===========', conv1)

# Primary capsules: a second convolution whose channels are regrouped into
# caps1_n_maps maps of caps1_n_dims-dimensional capsule vectors.
caps1_n_maps = 32
caps1_n_dims = 8
conv2_params = {
    "filters": caps1_n_maps * caps1_n_dims,
    "kernel_size": 9,
    "strides": 2,
    "padding": "valid",
    "activation": tf.nn.relu
}
conv2 = tf.layers.conv2d(conv1, name="conv2", **conv2_params)
print('conv2===========', conv2)
# Number of primary capsules = capsule maps x spatial grid of conv2.
# The grid is read from conv2's static shape (6x6 for 28x28 inputs with the
# parameters above: 28-9+1 = 20, then (20-9)//2+1 = 6) instead of being
# hard-coded, so it stays correct if the conv hyper-parameters change.
caps1_n_caps = caps1_n_maps * int(conv2.shape[1]) * int(conv2.shape[2])
caps1_raw = tf.reshape(conv2, [-1, caps1_n_caps, caps1_n_dims], name="caps1_raw")

# 定义压缩函数 squash
def squash(s, axis=-1, epsilon=1e-7, name=None):
    """Apply the capsule "squash" non-linearity along ``axis``.

    Each vector of ``s`` keeps its direction, and its length is rescaled by
    ||s||^2 / (1 + ||s||^2).

    Args:
        s: tensor of capsule vectors.
        axis: axis holding the vector components (default: the last one).
        epsilon: small constant added to ||s||^2 before the square root so
            the division by ||s|| never divides by zero.
        name: optional name scope; defaults to "squash".

    Returns:
        A tensor with the same shape as ``s``.
    """
    with tf.name_scope(name, default_name="squash"):
        norm_sq = tf.reduce_sum(tf.square(s), axis=axis, keep_dims=True)
        length = tf.sqrt(norm_sq + epsilon)
        scale = norm_sq / (1. + norm_sq)
        direction = s / length
    # The final product is created outside the name scope, as before.
    return scale * direction
# squash adds a small epsilon (1e-7) under the square root so the denominator
# ||s|| can never be zero.
# Squash caps1_raw into caps1_output; its shape is also [?, 1152, 8].
caps1_output = squash(caps1_raw, name="caps1_output")

# Digit capsules: caps2_n_caps capsules of caps2_n_dims dimensions each.
caps2_n_caps = 10
caps2_n_dims = 16
init_sigma = 0.01
# W holds one (caps2_n_dims x caps1_n_dims) transformation matrix per
# (primary capsule, digit capsule) pair: shape [1, 1152, 10, 16, 8].
W_init = tf.random_normal(
    shape=(1, caps1_n_caps, caps2_n_caps, caps2_n_dims, caps1_n_dims),
    stddev=init_sigma, dtype=tf.float32, name="W_init")
W = tf.Variable(W_init, name="W")
# Tile W along the batch axis with the *runtime* batch size so the graph
# accepts any batch size (e.g. a whole evaluation set), not only the Python
# training batch_size. A separate name keeps the training batch_size intact.
# W_tiled: [?, 1152, 10, 16, 8].
caps_batch_size = tf.shape(X)[0]
W_tiled = tf.tile(W, [caps_batch_size, 1, 1, 1, 1], name="W_tiled")

# Bring caps1_output from [?, 1152, 8] to [?, 1152, 10, 8, 1] so it can be
# matrix-multiplied with W_tiled:
#   expand_dims(-1)       -> caps1_output_expanded [?, 1152, 8, 1]
#   expand_dims(2)        -> caps1_output_tile     [?, 1152, 1, 8, 1]
#   tile 10x along axis 2 -> caps1_output_tiled    [?, 1152, 10, 8, 1]
caps1_output_expanded = tf.expand_dims(caps1_output, -1, name="caps1_output_expanded")
caps1_output_tile = tf.expand_dims(caps1_output_expanded, 2, name="caps1_output_tile")
caps1_output_tiled = tf.tile(caps1_output_tile, [1, 1, caps2_n_caps, 1, 1], name="caps1_output_tiled")

# Prediction vectors u_hat = W . u, shape [?, 1152, 10, 16, 1].
caps2_predicted = tf.matmul(W_tiled, caps1_output_tiled, name="caps2_predicted")

# Dynamic routing by agreement.
n_routing_rounds = 3

# Round 1: raw routing logits b start at zero. b and c: [?, 1152, 10, 1, 1].
# The batch dimension is taken from caps2_predicted at run time, so any batch
# size can be fed (not only the Python-level training batch_size).
routing_batch_size = tf.shape(caps2_predicted)[0]
b = tf.zeros([routing_batch_size, caps1_n_caps, caps2_n_caps, 1, 1],
             dtype=np.float32, name="raw_weights")
# Softmax over axis 2 (the caps2 axis): for each primary capsule the coupling
# coefficients to all digit capsules sum to one.
c = tf.nn.softmax(b, dim=2, name="routing_weights")
# weighted_predictions: [?, 1152, 10, 16, 1]; summing over axis 1 (the caps1
# axis) gives s and v of shape [?, 1, 10, 16, 1].
weighted_predictions = tf.multiply(c, caps2_predicted, name="weighted_predictions")
s = tf.reduce_sum(weighted_predictions, axis=1, keep_dims=True, name="weighted_sum")
v = squash(s, axis=-2, name="caps2_output_round_1")

# Rounds 2..n_routing_rounds: add the agreement u_hat^T . v to the logits and
# recompute c, s and v. Op names match the previous hand-unrolled version.
for round_idx in range(2, n_routing_rounds + 1):
    v_tiled = tf.tile(v, [1, caps1_n_caps, 1, 1, 1],
                      name="caps2_output_round_{}_tiled".format(round_idx - 1))
    # [?, 1152, 10, 16, 1]^T x [?, 1152, 10, 16, 1] -> [?, 1152, 10, 1, 1]
    agreement = tf.matmul(caps2_predicted, v_tiled, transpose_a=True, name="agreement")
    b = tf.add(b, agreement, name="raw_weights_round_{}".format(round_idx))
    c = tf.nn.softmax(b, dim=2, name="routing_weights_round_{}".format(round_idx))
    weighted_predictions = tf.multiply(
        c, caps2_predicted, name="weighted_predictions_round_{}".format(round_idx))
    s = tf.reduce_sum(weighted_predictions, axis=1, keep_dims=True,
                      name="weighted_sum_round_{}".format(round_idx))
    v = squash(s, axis=-2, name="caps2_output_round_{}".format(round_idx))

# Margin loss.
m_plus = 0.9
m_minus = 0.1
lambda_ = 0.5
T = tf.one_hot(y, depth=caps2_n_caps, name="T")
# Length of each digit-capsule output vector: [?, 1, 10, 1, 1].
# Computed as sqrt(sum(v^2) + eps) instead of tf.norm, whose gradient is
# undefined (NaN) at the zero vector; the epsilon keeps it finite.
norm_epsilon = 1e-7
v_squared_norm = tf.reduce_sum(tf.square(v), axis=-2, keep_dims=True,
                               name="caps2_output_squared_norm")
v_norm = tf.sqrt(v_squared_norm + norm_epsilon, name="caps2_output_norm")
# Target-class term: penalizes capsule lengths below m_plus.
FP_raw = tf.square(tf.maximum(0., m_plus - v_norm), name="FP_raw")
FP = tf.reshape(FP_raw, shape=(-1, caps2_n_caps), name="FP")
# Other-class term: penalizes capsule lengths above m_minus.
FN_raw = tf.square(tf.maximum(0., v_norm - m_minus), name="FN_raw")
FN = tf.reshape(FN_raw, shape=(-1, caps2_n_caps), name="FN")
L = tf.add(T * FP, lambda_ * (1.0 - T) * FN, name="L")
margin_loss = tf.reduce_mean(tf.reduce_sum(L, axis=1), name="margin_loss")

# Mask: only the digit capsule of the chosen class is passed to the decoder.
# y_pred defaults to the predicted class (index of the longest capsule) and
# can still be overridden through feed_dict. A computed default is required:
# tf.cond evaluates every tensor created outside its branches regardless of
# the branch taken, so a bare placeholder here would have to be fed on every
# run, including training steps that only feed X and y.
# (For the same reason y must be fed even when mask_with_labels is False.)
y_pred_default = tf.argmax(tf.reshape(v_norm, [-1, caps2_n_caps]), axis=1,
                           name="y_pred_default")
y_pred = tf.placeholder_with_default(y_pred_default, shape=[None], name="y_pred")
# True (the default) masks with the true labels y, intended for training;
# feed False to mask with y_pred.
mask_with_labels = tf.placeholder_with_default(True, shape=(), name="mask_with_labels")
reconstruction_targets = tf.cond(mask_with_labels,  # condition
                                 lambda: y,  # if True
                                 lambda: y_pred,  # if False
                                 name="reconstruction_targets")
reconstruction_mask = tf.one_hot(reconstruction_targets, depth=caps2_n_caps,
                                 name="reconstruction_mask")
reconstruction_mask_reshaped = tf.reshape(
    reconstruction_mask, [-1, 1, caps2_n_caps, 1, 1], name="reconstruction_mask_reshaped")
caps2_output_masked = tf.multiply(v, reconstruction_mask_reshaped, name="caps2_output_masked")
# Decoder: reconstructs the 28x28 input image from the masked digit-capsule
# outputs with three fully connected layers.
n_hidden1 = 512
n_hidden2 = 1024
n_output = 28 * 28
decoder_input = tf.reshape(caps2_output_masked,
                           [-1, caps2_n_caps * caps2_n_dims], name="decoder_input")
with tf.name_scope("decoder"):
    hidden1 = tf.layers.dense(decoder_input, n_hidden1, activation=tf.nn.relu, name="hidden1")
    hidden2 = tf.layers.dense(hidden1, n_hidden2, activation=tf.nn.relu, name="hidden2")
    decoder_output = tf.layers.dense(hidden2, n_output, activation=tf.nn.sigmoid, name="decoder_output")

# Reconstruction loss: sum of squared pixel errors per image, averaged over
# the batch. Averaging (instead of summing over the whole batch) keeps its
# scale independent of the batch size and consistent with margin_loss, which
# is also a mean over the batch.
X_flat = tf.reshape(X, [-1, n_output], name="X_flat")
squared_difference = tf.square(X_flat - decoder_output, name="squared_difference")
reconstruction_loss = tf.reduce_mean(tf.reduce_sum(squared_difference, axis=1),
                                     name="reconstruction_loss")

# Total loss: alpha down-weights the reconstruction term relative to the
# margin loss.
alpha = 0.0005
loss = tf.add(margin_loss, alpha * reconstruction_loss, name="loss")
# Accuracy: fraction of samples whose y_pred equals the true label y.
# Meaningful only when y_pred holds predicted class indices.
correct = tf.equal(y, y_pred, name="correct")
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32), name="accuracy")

# Adam optimizer with its default hyper-parameters.
optimizer = tf.train.AdamOptimizer()
training_op = optimizer.minimize(loss, name="training_op")

# Global initializer. It must be created AFTER optimizer.minimize():
# AdamOptimizer creates additional variables of its own inside minimize(), and
# tf.global_variables_initializer() only covers the variables that exist when
# it is called. Created earlier, the first training step would fail on
# uninitialized optimizer variables.
init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(21):
        loss_val = float('nan')  # stays NaN if an epoch has no batches
        for batch in range(n_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            # Reshape images to X's [?, 28, 28, 1] layout. Labels are one-hot
            # (read_data_sets(one_hot=True)) but y expects integer class
            # indices, so take each row's argmax. Taking column 0 would only
            # yield 1 for digit "0" and 0 for every other digit.
            _, loss_val = sess.run(
                [training_op, loss],
                feed_dict={X: np.reshape(batch_xs, [-1, 28, 28, 1]),
                           y: np.argmax(batch_ys, axis=1)})
        print("Epoch {}: last batch loss {:.5f}".format(epoch, loss_val))
